{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Regression - Multicomponent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/training_regression_multicomponent.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install chemprop from GitHub if running in Google Colab\n",
    "import os\n",
    "\n",
    "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
    "    try:\n",
    "        import chemprop\n",
    "    except ImportError:\n",
    "        !git clone https://github.com/chemprop/chemprop.git\n",
    "        %cd chemprop\n",
    "        !pip install .\n",
    "        %cd examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from lightning import pytorch as pl\n",
    "from pathlib import Path\n",
    "\n",
    "from chemprop import data, featurizers, models, nn\n",
    "from chemprop.nn import metrics\n",
    "from chemprop.models import multi\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Change your data inputs here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parent\n",
    "input_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol+mol\" / \"mol+mol.csv\" # path to your data .csv file containing SMILES strings and target values\n",
    "smiles_columns = ['smiles', 'solvent'] # name of the column containing SMILES strings\n",
    "target_columns = ['peakwavs_max'] # list of names of the columns containing targets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>solvent</th>\n",
       "      <th>peakwavs_max</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C...</td>\n",
       "      <td>ClCCl</td>\n",
       "      <td>642.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c...</td>\n",
       "      <td>ClCCl</td>\n",
       "      <td>420.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]...</td>\n",
       "      <td>O</td>\n",
       "      <td>544.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>c1ccc2[nH]ccc2c1</td>\n",
       "      <td>O</td>\n",
       "      <td>290.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c...</td>\n",
       "      <td>ClC(Cl)Cl</td>\n",
       "      <td>736.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)...</td>\n",
       "      <td>C1CCOC1</td>\n",
       "      <td>359.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc...</td>\n",
       "      <td>C1CCCCC1</td>\n",
       "      <td>386.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O</td>\n",
       "      <td>CCO</td>\n",
       "      <td>425.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)...</td>\n",
       "      <td>c1ccccc1</td>\n",
       "      <td>324.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)...</td>\n",
       "      <td>ClCCl</td>\n",
       "      <td>391.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               smiles    solvent  peakwavs_max\n",
       "0   CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C...      ClCCl         642.0\n",
       "1   C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c...      ClCCl         420.0\n",
       "2   CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]...          O         544.0\n",
       "3                                    c1ccc2[nH]ccc2c1          O         290.0\n",
       "4   CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c...  ClC(Cl)Cl         736.0\n",
       "..                                                ...        ...           ...\n",
       "95  COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)...    C1CCOC1         359.0\n",
       "96  COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc...   C1CCCCC1         386.0\n",
       "97          CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O        CCO         425.0\n",
       "98  Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)...   c1ccccc1         324.0\n",
       "99  Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)...      ClCCl         391.0\n",
       "\n",
       "[100 rows x 3 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_input = pd.read_csv(input_path)\n",
    "df_input"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get SMILES and targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "smiss = df_input.loc[:, smiles_columns].values\n",
    "ys = df_input.loc[:, target_columns].values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([['CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2CCCC)C(=O)N(CCCC)C1=S',\n",
       "         'ClCCl'],\n",
       "        ['C(=C/c1cnccn1)\\\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1',\n",
       "         'ClCCl'],\n",
       "        ['CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1',\n",
       "         'O'],\n",
       "        ['c1ccc2[nH]ccc2c1', 'O'],\n",
       "        ['CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5ccccc5c4C3(C)C)CCCC1=C2c1ccccc1C(=O)O',\n",
       "         'ClC(Cl)Cl']], dtype=object),\n",
       " array([[642.],\n",
       "        [420.],\n",
       "        [544.],\n",
       "        [290.],\n",
       "        [736.]]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Take a look at the first 5 SMILES strings and targets\n",
    "smiss[:5], ys[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make molecule datapoints\n",
    "Create a list of lists containing the molecule datapoints for each components. The target is stored in the 0th component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data = [[data.MoleculeDatapoint.from_smi(smis[0], y) for smis, y in zip(smiss, ys)]]\n",
    "all_data += [[data.MoleculeDatapoint.from_smi(smis[i]) for smis in smiss] for i in range(1, len(smiles_columns))]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Split data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perform data splitting for training, validation, and testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "component_to_split_by = 0 # index of the component to use for structure based splits\n",
    "mols = [d.mol for d in all_data[component_to_split_by]]\n",
    "train_indices, val_indices, test_indices = data.make_split_indices(mols, \"random\", (0.8, 0.1, 0.1))\n",
    "train_data, val_data, test_data = data.split_data_by_indices(\n",
    "    all_data, train_indices, val_indices, test_indices\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get MoleculeDataset for each components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()\n",
    "\n",
    "train_datasets = [data.MoleculeDataset(train_data[0][i], featurizer) for i in range(len(smiles_columns))]\n",
    "val_datasets = [data.MoleculeDataset(val_data[0][i], featurizer) for i in range(len(smiles_columns))]\n",
    "test_datasets = [data.MoleculeDataset(test_data[0][i], featurizer) for i in range(len(smiles_columns))]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construct multicomponent dataset and scale the targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_mcdset = data.MulticomponentDataset(train_datasets)\n",
    "scaler = train_mcdset.normalize_targets()\n",
    "val_mcdset = data.MulticomponentDataset(val_datasets)\n",
    "val_mcdset.normalize_targets(scaler)\n",
    "test_mcdset = data.MulticomponentDataset(test_datasets)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construct data loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = data.build_dataloader(train_mcdset)\n",
    "val_loader = data.build_dataloader(val_mcdset, shuffle=False)\n",
    "test_loader = data.build_dataloader(test_mcdset, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construct multicomponent MPNN"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MulticomponentMessagePassing\n",
    "- `blocks`: a list of message passing block used for each components\n",
    "- `n_components`: number of components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "mcmp = nn.MulticomponentMessagePassing(\n",
    "    blocks=[nn.BondMessagePassing() for _ in range(len(smiles_columns))],\n",
    "    n_components=len(smiles_columns),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aggregation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "agg = nn.MeanAggregation()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RegressionFFN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "ffn = nn.RegressionFFN(\n",
    "    input_dim=mcmp.output_dim,\n",
    "    output_transform=output_transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric_list = [metrics.RMSE(), metrics.MAE()] # Only the first metric is used for training and early stopping"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MulticomponentMPNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MulticomponentMPNN(\n",
       "  (message_passing): MulticomponentMessagePassing(\n",
       "    (blocks): ModuleList(\n",
       "      (0-1): 2 x BondMessagePassing(\n",
       "        (W_i): Linear(in_features=86, out_features=300, bias=False)\n",
       "        (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "        (W_o): Linear(in_features=372, out_features=300, bias=True)\n",
       "        (dropout): Dropout(p=0.0, inplace=False)\n",
       "        (tau): ReLU()\n",
       "        (V_d_transform): Identity()\n",
       "        (graph_transform): Identity()\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (agg): MeanAggregation()\n",
       "  (bn): Identity()\n",
       "  (predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=600, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0]])\n",
       "    (output_transform): UnscaleTransform()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metrics): ModuleList(\n",
       "    (0): RMSE(task_weights=[[1.0]])\n",
       "    (1): MAE(task_weights=[[1.0]])\n",
       "    (2): MSE(task_weights=[[1.0]])\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mcmpnn = multi.MulticomponentMPNN(\n",
    "    mcmp,\n",
    "    agg,\n",
    "    ffn,\n",
    "    metrics=metric_list,\n",
    ")\n",
    "\n",
    "mcmpnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set up trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(\n",
    "    logger=False,\n",
    "    enable_checkpointing=True,\n",
    "    enable_progress_bar=True,\n",
    "    accelerator=\"auto\",\n",
    "    devices=1,\n",
    "    max_epochs=20, # number of epochs to train for\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
      "\n",
      "  | Name            | Type                         | Params | Mode \n",
      "-------------------------------------------------------------------------\n",
      "0 | message_passing | MulticomponentMessagePassing | 455 K  | train\n",
      "1 | agg             | MeanAggregation              | 0      | train\n",
      "2 | bn              | Identity                     | 0      | train\n",
      "3 | predictor       | RegressionFFN                | 180 K  | train\n",
      "4 | X_d_transform   | Identity                     | 0      | train\n",
      "5 | metrics         | ModuleList                   | 0      | train\n",
      "-------------------------------------------------------------------------\n",
      "636 K     Trainable params\n",
      "0         Non-trainable params\n",
      "636 K     Total params\n",
      "2.544     Total estimated model params size (MB)\n",
      "35        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  4.50it/s, train_loss_step=0.422, val_loss=0.301, train_loss_epoch=0.659]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  4.10it/s, train_loss_step=0.422, val_loss=0.301, train_loss_epoch=0.659]\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(mcmpnn, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 28.64it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/mae          </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     87.1765365600586      </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/rmse         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    105.41293334960938     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/mae         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    87.1765365600586     \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/rmse        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   105.41293334960938    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results = trainer.test(mcmpnn, test_loader)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
